{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "a028c4a1cf0f499f91b6869da083752c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_f55eb0dac23f4d3aba01758a0cea75ea",
              "IPY_MODEL_fb7b3a79dd1842f99d4cfcc5ada8f4d4",
              "IPY_MODEL_87cf8f6df33c4db89e34199aaf8ba5ec"
            ],
            "layout": "IPY_MODEL_65d4420626e04900b60931699d722db7"
          }
        },
        "f55eb0dac23f4d3aba01758a0cea75ea": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2a9dcfe38d4d4fbb8384e45c3f25e3ee",
            "placeholder": "​",
            "style": "IPY_MODEL_8d4e03351b714120b82c620a43c086bb",
            "value": "Resolving data files: 100%"
          }
        },
        "fb7b3a79dd1842f99d4cfcc5ada8f4d4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_99ce70b89ccd4260be4aeef6782dd5a9",
            "max": 208,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_7578f4aaad80425cb8ba293487c9d718",
            "value": 208
          }
        },
        "87cf8f6df33c4db89e34199aaf8ba5ec": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0ecc2c921f0f412ca081d98e14977f35",
            "placeholder": "​",
            "style": "IPY_MODEL_f8f02a4713624bb4b47c671ab666490b",
            "value": " 208/208 [00:00&lt;00:00, 14188.61it/s]"
          }
        },
        "65d4420626e04900b60931699d722db7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2a9dcfe38d4d4fbb8384e45c3f25e3ee": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8d4e03351b714120b82c620a43c086bb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "99ce70b89ccd4260be4aeef6782dd5a9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7578f4aaad80425cb8ba293487c9d718": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0ecc2c921f0f412ca081d98e14977f35": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f8f02a4713624bb4b47c671ab666490b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "351eadad22ac4809a88eb78afbdb9a8d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_837ce06401b84abf99801b7519a5ae76",
              "IPY_MODEL_d47b3b40f5c44ba996a7ad1f1061e21e",
              "IPY_MODEL_a9086b91f5c941e0bc0e62d578b0b755"
            ],
            "layout": "IPY_MODEL_a13086d3271f4c82985c29e0fcee81e8"
          }
        },
        "837ce06401b84abf99801b7519a5ae76": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ccafe205eda24a4f92405eef48404e32",
            "placeholder": "​",
            "style": "IPY_MODEL_d2bc13f77b454397a60041d2cf2868f2",
            "value": "Resolving data files: 100%"
          }
        },
        "d47b3b40f5c44ba996a7ad1f1061e21e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7d63c2fa1fed4f23b46020b38268164e",
            "max": 208,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c1e9087f60fb42f2a44962e05492a03c",
            "value": 208
          }
        },
        "a9086b91f5c941e0bc0e62d578b0b755": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e6f7863011094e2594dabc6e25b577e6",
            "placeholder": "​",
            "style": "IPY_MODEL_cd80623c72b44c1f9801a119254c232f",
            "value": " 208/208 [00:00&lt;00:00, 15313.59it/s]"
          }
        },
        "a13086d3271f4c82985c29e0fcee81e8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ccafe205eda24a4f92405eef48404e32": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d2bc13f77b454397a60041d2cf2868f2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7d63c2fa1fed4f23b46020b38268164e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c1e9087f60fb42f2a44962e05492a03c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e6f7863011094e2594dabc6e25b577e6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cd80623c72b44c1f9801a119254c232f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OWNcqlzr4Guz",
        "outputId": "b2962bf8-cc95-4d86-bd74-448504abd265"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Fri Jan  5 10:43:27 2024       \n",
            "+---------------------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |\n",
            "|-----------------------------------------+----------------------+----------------------+\n",
            "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                                         |                      |               MIG M. |\n",
            "|=========================================+======================+======================|\n",
            "|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   70C    P8              13W /  70W |      0MiB / 15360MiB |      0%      Default |\n",
            "|                                         |                      |                  N/A |\n",
            "+-----------------------------------------+----------------------+----------------------+\n",
            "                                                                                         \n",
            "+---------------------------------------------------------------------------------------+\n",
            "| Processes:                                                                            |\n",
            "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
            "|        ID   ID                                                             Usage      |\n",
            "|=======================================================================================|\n",
            "|  No running processes found                                                           |\n",
            "+---------------------------------------------------------------------------------------+\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Importing Libraries"
      ],
      "metadata": {
        "id": "sCK97ecGPwEE"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "Bho0LWg-S6TV"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "from datasets import load_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Checking GPU Availability\n",
        "Run this notebook on a GPU runtime: on a CPU it will take a long time and may melt your CPU :)\n",
        "\n",
        "The next two cells confirm that CUDA is available and show the name of the assigned device."
      ],
      "metadata": {
        "id": "qPxPsBgcPylz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "torch.cuda.is_available()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DTy_t9_ItCEz",
        "outputId": "0cc73ddf-47b2-40e0-d432-e160f7d5f0f2"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "torch.cuda.get_device_name()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 56
        },
        "id": "vF2yn76JbATP",
        "outputId": "e382fc4f-fe14-40e1-f9be-b609cfb9251a"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Tesla T4'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
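    {
      "cell_type": "markdown",
      "source": [
        "### Hyperparameters from Training\n",
        "\n",
        "**Note (review):** `n_embd` (1024) is not evenly divisible by `n_head` (20), i.e. 1024 / 20 = 51.2. Confirm how the model code derives the per-head size before changing either value."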
      ],
      "metadata": {
        "id": "xH7kgxRPP71r"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# hyperparameters\n",
        "batch_size = 16 # how many independent sequences will we process in parallel?\n",
        "block_size = 300 # what is the maximum context length for predictions?\n",
        "max_iters = 10000\n",
        "eval_interval = 500\n",
        "learning_rate = 3e-4\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "eval_iters = 200\n",
        "n_embd = 1024\n",
        "n_head = 20\n",
        "n_layer = 10\n",
        "dropout = 0.05\n",
        "# ------------\n",
        "\n",
        "torch.manual_seed(1337)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "twfA5lWFsB2m",
        "outputId": "eaf612c0-423a-493b-8e47-83e89b666630"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<torch._C.Generator at 0x7dd8fafeb9f0>"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Loading Our Sitcom Screenplay Dataset"
      ],
      "metadata": {
        "id": "g0KgBTDCP-Jl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with open('sitcom_pre_training_dataset.txt', 'r', encoding='utf-8') as f:\n",
        "    text = f.read().lower()"
      ],
      "metadata": {
        "id": "szy1DnZesPO3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Loading BERT Tokenizer"
      ],
      "metadata": {
        "id": "J8gZBgMMQAno"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BertTokenizer"
      ],
      "metadata": {
        "id": "J0yZXhF4xpAr"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "enc = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
      ],
      "metadata": {
        "id": "9PkPQURA5b-k",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "27e7a3a1-06dd-46d0-b740-50df3f942d7d"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:72: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "vocab_size = len(enc.vocab)"
      ],
      "metadata": {
        "id": "7aoW9qHbNlBI"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tokens = enc.encode(\"Scene: An empty apartment in New York City\")"
      ],
      "metadata": {
        "id": "gCizIyLsN70X"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tokens"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Js-fhixS5349",
        "outputId": "f57c6379-849a-4cab-a407-70f458f5ad51"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[101, 3496, 1024, 2019, 4064, 4545, 1999, 2047, 2259, 2103, 102]"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Vocabulary Size: {vocab_size}\")\n",
        "print(f\"Raw Text: Scene: An empty apartment in New York City, Tokens: {tokens}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1WbmzAfvaOUv",
        "outputId": "e570e3d3-ef65-4297-b2fc-218739115556"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Vocabulary Size: 30522\n",
            "Raw Text: Scene: An empty apartment in New York City, Tokens: [101, 3496, 1024, 2019, 4064, 4545, 1999, 2047, 2259, 2103, 102]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Train-Test Split"
      ],
      "metadata": {
        "id": "x8ne3oOrQEM4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Train and test splits\n",
        "data = torch.tensor(enc.encode(text), dtype=torch.long)\n",
        "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
        "n = len(data) - n\n",
        "random_part = len(data) // 2\n",
        "train_data1 = data[:random_part]\n",
        "val_data = data[random_part: random_part + n]\n",
        "train_data2 = data[random_part + n:]\n",
        "train_data = torch.hstack([train_data1, train_data2])\n",
        "print(train_data.shape, train_data1.shape, train_data2.shape)"
      ],
      "metadata": {
        "id": "eAEKyFjYsf9a",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e9e8e748-55e4-42a2-acff-ddfee39a5292"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Token indices sequence length is longer than the specified maximum sequence length for this model (967501 > 512). Running this sequence through the model will result in indexing errors\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([870750]) torch.Size([483750]) torch.Size([387000])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(train_data.shape, val_data.shape)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "A8XcF2JUfkrB",
        "outputId": "9e0b9313-835e-4f49-b035-a3e684755dcd"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([870750]) torch.Size([96751])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_data[-10:]"
      ],
      "metadata": {
        "id": "CvXLFk1jAotV",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d8ce2b36-ac47-4c43-bce2-031c97a14460"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([7324, 1007, 6864, 1012, 2097, 2017, 5914, 2033, 1029,  102])"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Custom Functions for generating a Batch of Dataset and Computing Losses"
      ],
      "metadata": {
        "id": "nRt1GXl-QGoy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# data loading\n",
        "def get_batch(split='train'):\n",
        "    # generate a small batch of data of inputs x and targets y\n",
        "    data = train_data if split == 'train' else train_data\n",
        "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
        "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
        "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
        "    x, y = x.to(device), y.to(device)\n",
        "    return x, y\n",
        "\n",
        "@torch.no_grad()\n",
        "def estimate_loss():\n",
        "    out = {}\n",
        "    model.eval()\n",
        "    for split in ['train', 'val']:\n",
        "        losses = torch.zeros(eval_iters)\n",
        "        for k in range(eval_iters):\n",
        "            X, Y = get_batch(split)\n",
        "            logits, loss = model(X, Y)\n",
        "            losses[k] = loss.item()\n",
        "        out[split] = losses.mean()\n",
        "    model.train()\n",
        "    return out"
      ],
      "metadata": {
        "id": "NmEdmLtSsiNE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "X, y = get_batch()"
      ],
      "metadata": {
        "id": "-g7-M1bHiLtE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "X.shape, y.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pqBkRU_uiM4Q",
        "outputId": "1120b854-4437-47af-9fd0-0698ee2de8d6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([16, 300]), torch.Size([16, 300]))"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Language Model's Head Class"
      ],
      "metadata": {
        "id": "DlE9ykYAQKhp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Head(nn.Module):\n",
        "    \"\"\" one head of self-attention \"\"\"\n",
        "\n",
        "    def __init__(self, head_size):\n",
        "        super().__init__()\n",
        "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # input of size (batch, time-step, channels)\n",
        "        # output of size (batch, time-step, head size)\n",
        "        B,T,C = x.shape\n",
        "        k = self.key(x)   # (B,T,hs)\n",
        "        q = self.query(x) # (B,T,hs)\n",
        "        # compute attention scores (\"affinities\")\n",
        "        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
        "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
        "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
        "        wei = self.dropout(wei)\n",
        "        # perform the weighted aggregation of the values\n",
        "        v = self.value(x) # (B,T,hs)\n",
        "        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n",
        "        return out"
      ],
      "metadata": {
        "id": "g67f6Xqfslna"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Multi-Headed Attention Class"
      ],
      "metadata": {
        "id": "hPtNV1YDQMrw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class MultiHeadAttention(nn.Module):\n",
        "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
        "\n",
        "    def __init__(self, num_heads, head_size):\n",
        "        super().__init__()\n",
        "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
        "        self.proj = nn.Linear(head_size * num_heads, n_embd)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
        "        out = self.dropout(self.proj(out))\n",
        "        return out"
      ],
      "metadata": {
        "id": "fVzecOHpspJ3"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Feed Forward Neural Network Class"
      ],
      "metadata": {
        "id": "GsSJe-s-QO6w"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class FeedForward(nn.Module):\n",
        "    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
        "\n",
        "    def __init__(self, n_embd):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(n_embd, 4 * n_embd),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(4 * n_embd, n_embd),\n",
        "            nn.Dropout(dropout),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x)\n"
      ],
      "metadata": {
        "id": "NqgCG6zCss_M"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Transformer Block Class"
      ],
      "metadata": {
        "id": "bmfxeg5eQRib"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Block(nn.Module):\n",
        "    \"\"\" Transformer block: communication followed by computation \"\"\"\n",
        "\n",
        "    def __init__(self, n_embd, n_head):\n",
        "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
        "        super().__init__()\n",
        "        head_size = n_embd // n_head\n",
        "        self.sa = MultiHeadAttention(n_head, head_size)\n",
        "        self.ffwd = FeedForward(n_embd)\n",
        "        self.ln1 = nn.LayerNorm(n_embd)\n",
        "        self.ln2 = nn.LayerNorm(n_embd)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x + self.sa(self.ln1(x))\n",
        "        x = x + self.ffwd(self.ln2(x))\n",
        "        return x"
      ],
      "metadata": {
        "id": "NEZYVyajsvCi"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Final Language Model Class"
      ],
      "metadata": {
        "id": "mZ1FT052QVO5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class GPTLanguageModel(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # each token directly reads off the logits for the next token from a lookup table\n",
        "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
        "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
        "        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
        "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
        "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
        "\n",
        "        # better init, not covered in the original GPT video, but important, will cover in followup video\n",
        "        self.apply(self._init_weights)\n",
        "\n",
        "    def _init_weights(self, module):\n",
        "        if isinstance(module, nn.Linear):\n",
        "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
        "            if module.bias is not None:\n",
        "                torch.nn.init.zeros_(module.bias)\n",
        "        elif isinstance(module, nn.Embedding):\n",
        "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
        "\n",
        "    def forward(self, idx, targets=None):\n",
        "        B, T = idx.shape\n",
        "\n",
        "        # idx and targets are both (B,T) tensor of integers\n",
        "        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
        "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
        "        x = tok_emb + pos_emb # (B,T,C)\n",
        "        x = self.blocks(x) # (B,T,C)\n",
        "        x = self.ln_f(x) # (B,T,C)\n",
        "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
        "\n",
        "        if targets is None:\n",
        "            loss = None\n",
        "        else:\n",
        "            B, T, C = logits.shape\n",
        "            logits = logits.view(B*T, C)\n",
        "            targets = targets.view(B*T)\n",
        "            loss = F.cross_entropy(logits, targets)\n",
        "\n",
        "        return logits, loss\n",
        "\n",
        "    def generate(self, idx, max_new_tokens):\n",
        "        # idx is (B, T) array of indices in the current context\n",
        "        for _ in range(max_new_tokens):\n",
        "            # crop idx to the last block_size tokens\n",
        "            idx_cond = idx[:, -block_size:]\n",
        "            # get the predictions\n",
        "            logits, loss = self(idx_cond)\n",
        "            # focus only on the last time step\n",
        "            logits = logits[:, -1, :] # becomes (B, C)\n",
        "            # apply softmax to get probabilities\n",
        "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
        "            # sample from the distribution\n",
        "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
        "            # append sampled index to the running sequence\n",
        "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
        "        return idx\n"
      ],
      "metadata": {
        "id": "YEiNCoedsxBG"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Instantiating the Model\n",
        "Use Pre-trained Weights otherwise start with Random Weights"
      ],
      "metadata": {
        "id": "uwtVgHRgQXZi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = GPTLanguageModel()\n",
        "\n",
        "# Model path will be different for our run, as you will be saving it to your\n",
        "# local drive or Google Drive or any other storage solution\n",
        "# Comment out the load_state_dict is weights not present\n",
        "model_path = './drive/MyDrive/sitcom_20_head_10_layers_1024_nembd_20000_steps.pt'\n",
        "print(f\"Loading weights from {model_path}\")\n",
        "model.load_state_dict(torch.load(model_path))\n",
        "model = model.to(device)\n",
        "print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')"
      ],
      "metadata": {
        "id": "M2uIhhzVyqXG",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6637ad37-0188-41f0-a9c0-9c31d6c0035a"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading weights from ./drive/MyDrive/sitcom_20_head_10_layers_1024_nembd_20000_steps.pt\n",
            "188.616506 M parameters\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Model Training Loop"
      ],
      "metadata": {
        "id": "4p3K4wCxQkI4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import gc\n",
        "torch.cuda.empty_cache()\n",
        "gc.collect()\n",
        "\n",
        "# create a PyTorch optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
        "\n",
        "for iter in range(max_iters):\n",
        "    # every 500 iterations, we evaluate the loss on train and val sets\n",
        "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
        "        losses = estimate_loss()\n",
        "        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
        "\n",
        "    # sample a batch of data\n",
        "    xb, yb = get_batch('train')\n",
        "\n",
        "    # evaluate the loss\n",
        "    logits, loss = model(xb, yb)\n",
        "    optimizer.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n"
      ],
      "metadata": {
        "id": "U1xlEfdAs2OJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "221e401d-a505-4679-9aa1-443f37a3e47e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "188.616506 M parameters\n",
            "step 0: train loss 0.1535, val loss 0.1555\n",
            "step 500: train loss 0.1533, val loss 0.1556\n",
            "step 1000: train loss 0.1485, val loss 0.1483\n",
            "step 1500: train loss 0.1438, val loss 0.1433\n",
            "step 2000: train loss 0.1377, val loss 0.1399\n",
            "step 2500: train loss 0.1323, val loss 0.1337\n",
            "step 3000: train loss 0.1298, val loss 0.1291\n",
            "step 3500: train loss 0.1257, val loss 0.1245\n",
            "step 4000: train loss 0.1224, val loss 0.1223\n",
            "step 4500: train loss 0.1189, val loss 0.1193\n",
            "step 5000: train loss 0.1147, val loss 0.1166\n",
            "step 5500: train loss 0.1137, val loss 0.1142\n",
            "step 6000: train loss 0.1124, val loss 0.1113\n",
            "step 6500: train loss 0.1088, val loss 0.1085\n",
            "step 7000: train loss 0.1073, val loss 0.1073\n",
            "step 7500: train loss 0.1057, val loss 0.1055\n",
            "step 8000: train loss 0.1025, val loss 0.1027\n",
            "step 8500: train loss 0.1013, val loss 0.1014\n",
            "step 9000: train loss 0.1001, val loss 0.0998\n",
            "step 9500: train loss 0.0986, val loss 0.0992\n",
            "step 9999: train loss 0.0986, val loss 0.0982\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Generating Text from our Trained Model"
      ],
      "metadata": {
        "id": "7OsBpH6iQmFa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# generate from the model\n",
        "input_tokens = enc.encode(\"Scene: Sheldon and Penny arguing about Star Wars\")\n",
        "context = torch.tensor([input_tokens],\n",
        "                       device=device)\n",
        "out = enc.decode(model.generate(context, max_new_tokens=500)[0].tolist())\n",
        "# open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))"
      ],
      "metadata": {
        "id": "Q6bQ4TeYtTWC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "pattern = r'(\\w+)(\\s+):'\n",
        "modified_text = re.sub(pattern, r'\\n\\1:', text)\n",
        "modified_text"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 229
        },
        "id": "eyf_WZgDXDXU",
        "outputId": "c008f42b-e131-402b-ed58-32ddf2395f8d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\"[CLS] scene : sheldon and penny arguing about star wars [SEP] in a dream. sheldon : please, i don't start taking advantage of them. howard : what was i thinking? sheldon : sack of rome a sack's metaphorical kind, but i'm sure that even more of a more likely to develop my birthday at all. penny : oh, come on, i'll put that on fire. hello, howard. howard : raj, you're not gonna run away. sheldon : i couldn't take the compliment. penny : there are so many people here tonight doesn't make us less sales. sheldon : hey, fellas, and she wasn't the only reason i'd like to have a glass of water. penny : yup, i was thinking about that. leonard : i would have maybe head on the krmit to hate you. penny : hey, you don't have to be access to the university who will have to support you? leonard : oh, well, cool. penny : mm - hmm. leonard : at least i know the girlfriend you and i were thinking about investing in stuart's head out thinking about how electric bill? penny : oh, please sit down. sheldon : oh, i've got it. leonard : how do you know that? sheldon : before i tell you. penny : he said they made the world go where it's about to the restaurant. sheldon : if you're hungry, i can teach guy a ice - cream university isn't the perfect joke. amy : you know, i've never grown - up. penny : oh, boy, i'm very sorry. wyatt : i just wanted to go to the moon and get you back to the mall. penny : no. dad : this is getting kind of making tough of bread. susan : well, we're here all the early music. randall : scene's apartment. amy : okay. father : so, what do you expect us to do with this trip? our relationship agreement is an internal environment and see it? randall : outside every single thing is going to be. sheldon : i don't know. i like science, but it's still one of the phasers from the bed. how'd you get that? man : well, it's pretty cool. raj : hey, your friend down party man is as depressing love. lucy : wow. your friend ew. sheldon : come here. i don\""
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 31
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Your model save path will be different based on the storage solution you are using\n",
        "model_save_path = './drive/MyDrive/sitcom_20_head_10_layers_1024_nembd_20000_steps.pt'\n",
        "torch.save(model.state_dict(), model_save_path)"
      ],
      "metadata": {
        "id": "LEaxVqrxDpda"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Fine Tuning"
      ],
      "metadata": {
        "id": "asXyrEaDQpSv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Fine Tuning hyperparameters\n",
        "max_iters = 1000\n",
        "eval_interval = 50\n",
        "learning_rate = 1e-4\n",
        "eval_iters = 50"
      ],
      "metadata": {
        "id": "hXBZP5sRLm2Z"
      },
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "def get_batch_fine_tune(split='train'):\n",
        "    # generate a small batch of data of inputs x and targets y\n",
        "    block_size = 300\n",
        "    tbbt_dataset = load_dataset('./tbbt_scenes/', split='train')\n",
        "    output = []\n",
        "    for sample in tbbt_dataset:\n",
        "        text = sample['text']\n",
        "        enc_text = enc.encode(text)\n",
        "        if len(enc_text) <= 300:\n",
        "            enc_text = enc_text + [0] * (301 - len(enc_text))\n",
        "            r = 1\n",
        "        else:\n",
        "            r = len(enc_text) - 300\n",
        "        data = torch.tensor(enc_text, dtype=torch.long)\n",
        "        x = torch.stack([data[i:i+block_size] for i in range(r)])\n",
        "        y = torch.stack([data[i+1:i+block_size+1] for i in range(r)])\n",
        "        output.append((x, y))\n",
        "    return output\n",
        "\n",
        "@torch.no_grad()\n",
        "def estimate_loss_fine_tune():\n",
        "    out = {}\n",
        "    model.eval()\n",
        "    losses = []\n",
        "    for X, Y in get_batch_fine_tune():\n",
        "        X, Y = X.to(device), Y.to(device)\n",
        "        logits, loss = model(X, Y)\n",
        "        losses.append(loss.item())\n",
        "    out['ft'] = sum(losses) / len(losses)\n",
        "    model.train()\n",
        "    return out"
      ],
      "metadata": {
        "id": "KPpU2TtpZ4bS"
      },
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import gc\n",
        "torch.cuda.empty_cache()\n",
        "gc.collect()\n",
        "\n",
        "# create a PyTorch optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
        "\n",
        "for iter in range(max_iters):\n",
        "    # every 500 iterations, we evaluate the loss on train and val sets\n",
        "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
        "        losses = estimate_loss_fine_tune()\n",
        "        print(f\"step {iter}: Fine Tune loss {losses['ft']:.4f}\")\n",
        "\n",
        "    torch.cuda.empty_cache()\n",
        "    gc.collect()\n",
        "    # sample a batch of data\n",
        "    for X, Y in get_batch_fine_tune():\n",
        "      # evaluate the loss\n",
        "      X, Y = X.to(device), Y.to(device)\n",
        "      logits, loss = model(X, Y)\n",
        "      optimizer.zero_grad(set_to_none=True)\n",
        "      loss.backward()\n",
        "      optimizer.step()\n",
        "      torch.cuda.empty_cache()\n",
        "      gc.collect()\n",
        ""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 550,
          "referenced_widgets": [
            "a028c4a1cf0f499f91b6869da083752c",
            "f55eb0dac23f4d3aba01758a0cea75ea",
            "fb7b3a79dd1842f99d4cfcc5ada8f4d4",
            "87cf8f6df33c4db89e34199aaf8ba5ec",
            "65d4420626e04900b60931699d722db7",
            "2a9dcfe38d4d4fbb8384e45c3f25e3ee",
            "8d4e03351b714120b82c620a43c086bb",
            "99ce70b89ccd4260be4aeef6782dd5a9",
            "7578f4aaad80425cb8ba293487c9d718",
            "0ecc2c921f0f412ca081d98e14977f35",
            "f8f02a4713624bb4b47c671ab666490b",
            "351eadad22ac4809a88eb78afbdb9a8d",
            "837ce06401b84abf99801b7519a5ae76",
            "d47b3b40f5c44ba996a7ad1f1061e21e",
            "a9086b91f5c941e0bc0e62d578b0b755",
            "a13086d3271f4c82985c29e0fcee81e8",
            "ccafe205eda24a4f92405eef48404e32",
            "d2bc13f77b454397a60041d2cf2868f2",
            "7d63c2fa1fed4f23b46020b38268164e",
            "c1e9087f60fb42f2a44962e05492a03c",
            "e6f7863011094e2594dabc6e25b577e6",
            "cd80623c72b44c1f9801a119254c232f"
          ]
        },
        "id": "vw7POVJr7URQ",
        "outputId": "5e2000a5-9af5-48ed-ea07-b4448eb15b09"
      },
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Resolving data files:   0%|          | 0/208 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a028c4a1cf0f499f91b6869da083752c"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "step 0: Fine Tune loss 13.3444\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Resolving data files:   0%|          | 0/208 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "351eadad22ac4809a88eb78afbdb9a8d"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "error",
          "ename": "OutOfMemoryError",
          "evalue": "CUDA out of memory. Tried to allocate 16.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 9.06 MiB is free. Process 270361 has 14.74 GiB memory in use. Of the allocated memory 14.25 GiB is allocated by PyTorch, and 366.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-20-8c541fab8e15>\u001b[0m in \u001b[0;36m<cell line: 8>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     18\u001b[0m       \u001b[0;31m# evaluate the loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m       \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m       \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m       \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset_to_none\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m       \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1517\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1520\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1525\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1526\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1529\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-14-b7cf258caefa>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m     28\u001b[0m         \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposition_embedding_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (T,C)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtok_emb\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mpos_emb\u001b[0m \u001b[0;31m# (B,T,C)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblocks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B,T,C)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mln_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B,T,C)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m         \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlm_head\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B,T,vocab_size)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1517\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1520\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1525\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1526\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1529\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    213\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    214\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    216\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1517\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1520\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1525\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1526\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1529\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-13-1936b169da65>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msa\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mln1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mffwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mln2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1517\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1520\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1525\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1526\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1529\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-11-c5d95300b17a>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-11-c5d95300b17a>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1517\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1519\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1520\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1525\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1526\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1529\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-10-94cb707b3c26>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     18\u001b[0m         \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B,T,hs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m         \u001b[0;31m# compute attention scores (\"affinities\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m         \u001b[0mwei\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m0.5\u001b[0m \u001b[0;31m# (B, T, hs) @ (B, hs, T) -> (B, T, T)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m         \u001b[0mwei\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwei\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmasked_fill\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtril\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'-inf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B, T, T)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m         
\u001b[0mwei\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwei\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B, T, T)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 16.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 9.06 MiB is free. Process 270361 has 14.74 GiB memory in use. Of the allocated memory 14.25 GiB is allocated by PyTorch, and 366.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Ph4r5FjkGSlw"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}